Skip to content

Add Experimental limited sparse embedding bag #8905

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 17 commits into
base: master
Choose a base branch
from

Conversation

amjames
Copy link
Collaborator

@amjames amjames commented Mar 28, 2025

Users of torch_xla encounter an issue when using the sparse=True option with the Embedding or EmbeddingBag modules.

The gradient for weight is created as a sparse tensor and there is no dispatch registered for the combination of sparse creation APIs w/ the XLA key, or the Sparse functionality key and the XLA backed key used in conjunction.

This is a workaround that can be removed, ported to C++, or extended later:

  • SparseCOOTensor: a tensor subclass implementing the optimization and semantics of upstream SparseTensor. it is Composabile with the XLA device.
  • drop in replacements for F.embedding F.embedding_bag, nn.Embedding, and nn.EmbeddingBag which forward to a custom implementation of the backward and produce the above tensor subclass rather than a native torch sparse tensor.

The tensor subclass may have component tensors indices and values which have xla device without issue.

fixes #8719

@amjames amjames requested a review from ysiraichi March 28, 2025 22:23
@amjames amjames force-pushed the amjames/sparse_embedding_bag branch from aacfd1b to c67588f Compare March 28, 2025 22:27
@tengyifei tengyifei requested a review from qihqi April 10, 2025 22:00
@qihqi
Copy link
Collaborator

qihqi commented Apr 11, 2025

Hi @amjames if this passes test and is finished feel free to publish as PR and merge.

@miladm miladm requested a review from bhavya01 May 6, 2025 17:55
@amjames amjames force-pushed the amjames/sparse_embedding_bag branch from c67588f to c58b539 Compare May 7, 2025 18:24
@amjames amjames force-pushed the amjames/sparse_embedding_bag branch from c58b539 to f394790 Compare May 30, 2025 18:00
@amjames amjames marked this pull request as ready for review May 30, 2025 18:06
@amjames amjames changed the title [Draft] Add Experimental limited sparse embedding bag Add Experimental limited sparse embedding bag May 30, 2025
@amjames
Copy link
Collaborator Author

amjames commented May 30, 2025

A note for reviewers: the failure in xla_op1 shard appears to be unrelated, the tests for the new feature are in xla_op3

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Support for enabling sparse gradients in EmbeddingBag
3 participants